import numpy as np
import torch
from torch import nn
import math
import torch.nn.functional as F
from torch.nn.functional import one_hot
from .common_methods import MetricType

def cross_entropy(logit, target):
    """Soft-label cross-entropy between two batches of scores.

    Both ``logit`` and ``target`` are turned into class distributions with a
    softmax over ``dim=1``. The loss is ``-sum_c p_target[c] * log p_logit[c]``,
    averaged over dim 0.

    Args:
        logit: predicted scores, class axis on dim 1.
        target: reference scores (for example another set of predictions).
            They are softmax-normalised here, so raw scores are expected
            rather than probabilities.

    Returns:
        A scalar tensor.
    """
    # log_softmax is numerically safer than log(softmax(.)).
    log_prob = F.log_softmax(logit, dim=1)
    target_prob = F.softmax(target, dim=1)
    return -(target_prob * log_prob).sum(dim=1).mean()

def network_weight_gaussian_init(net: nn.Module):
    """Re-initialise ``net`` in place with a Gaussian scheme and return it.

    * ``nn.Conv2d`` / ``nn.Linear``: weight ~ N(0, 1), bias set to zero.
    * ``nn.BatchNorm2d`` / ``nn.GroupNorm``: affine weight = 1, bias = 0.

    Missing parameters are skipped: ``bias=False`` layers, and norm layers
    built with ``affine=False`` whose weight/bias are ``None``. All other
    module types are left untouched.
    """
    with torch.no_grad():
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight)
                if getattr(m, 'bias', None) is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # affine=False norm layers have weight = bias = None.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    return net

def recal_bn(network, xloader, recalbn, device):
    """Re-estimate BatchNorm running statistics from the first few batches.

    For every ``nn.BatchNorm2d`` that tracks running stats:
    - the running mean/var and batch counter are reset;
    - ``momentum`` is set to ``None`` (cumulative moving average).

    After the pass, the stats are the average over the batches seen.
    ``momentum`` stays ``None`` afterwards and the network is left in train
    mode.

    Args:
        network: model to recalibrate (modified in place). Its output may be
            a tensor or a tuple; it is discarded.
        xloader: iterable of ``(inputs, targets)`` batches.
        recalbn: number of batches to use.
        device: where to move the inputs (CUDA ordinal, ``torch.device`` or str).

    Returns:
        ``network``.
    """
    for m in network.modules():
        # Layers built with track_running_stats=False have no buffers to reset.
        if isinstance(m, torch.nn.BatchNorm2d) and m.running_mean is not None:
            m.running_mean.data.fill_(0)
            m.running_var.data.fill_(0)
            m.num_batches_tracked.data.zero_()
            m.momentum = None
    network.train()
    with torch.no_grad():
        for i, (inputs, _targets) in enumerate(xloader):
            if i >= recalbn:
                break
            # Only the forward pass matters: it updates the BN buffers.
            network(inputs.to(device, non_blocking=True))
    return network


def get_ntk_n(xloader, networks, mixup_gamma=1e-2, recalbn=0, train_mode=False, num_batch=-1, k=10):
    """Score each network by how concentrated its empirical NTK spectrum is.

    Steps:
    1. Re-initialise each network with ``network_weight_gaussian_init``.
    2. Collect per-sample gradients of the summed logits over the loader
       (see ``compute_ntk_grads``).
    3. Score = ``sum(top-k eigenvalues) / sum(all eigenvalues)`` of their
       Gram matrix. ``np.nan_to_num`` maps NaN to 0 and +/-inf to large
       finite values.

    Args:
        xloader: iterable of ``(inputs, targets)`` batches; targets are unused.
        networks: models to score (re-initialised in place).
        mixup_gamma, recalbn: accepted for API compatibility; unused.
        train_mode: run the networks in train mode instead of eval mode.
        num_batch: number of batches to consume (``<= 0`` means all).
        k: number of largest eigenvalues in the numerator. ``k <= 0`` gives 0;
            ``k`` larger than the sample count uses the whole spectrum.

    Returns:
        A list of floats, one per network.
    """
    device = torch.cuda.current_device()
    for network in networks:
        network_weight_gaussian_init(network)
        if train_mode:
            network.train()
        else:
            network.eval()

    grads = [[] for _ in networks]
    for i, (inputs, _targets) in enumerate(xloader):
        if num_batch > 0 and i >= num_batch:
            break
        inputs = inputs.cuda(device=device, non_blocking=True)
        for net_idx, network in enumerate(networks):
            # Clone so an in-place op in one network cannot leak into the next.
            grads[net_idx].extend(compute_ntk_grads(inputs.clone(), network))

    conds = []
    for net_grads in grads:
        jac = torch.stack(net_grads, 0)
        ntk = torch.einsum('nc,mc->nm', [jac, jac])
        eigenvalues = torch.linalg.eigvalsh(ntk, UPLO='U')  # ascending order
        top_k = eigenvalues[max(eigenvalues.numel() - k, 0):]
        conds.append(np.nan_to_num((top_k.sum() / eigenvalues.sum()).item(), copy=True))
    return conds


def get_ntk_partial(xloader, networks, mixup_gamma=1e-2, recalbn=0, train_mode=False, num_batch=-1, k=8):
    """Fraction of NTK spectral mass held by the ``k`` largest eigenvalues.

    Steps:
    1. Re-initialise every network with ``network_weight_gaussian_init``.
    2. Collect per-sample gradients of the summed logits over the loader.
    3. Form their Gram matrix and compute
       ``sum(top-k eigenvalues) / sum(all eigenvalues)`` per network.

    The top-k eigenvalues are indexed from the end of the ascending spectrum,
    so the score is correct for any number of samples, not only 64.

    Args:
        xloader: iterable of ``(inputs, targets)`` batches; targets are unused.
        networks: models to score (re-initialised in place).
        mixup_gamma, recalbn: accepted for API compatibility; unused.
        train_mode: run the networks in train mode instead of eval mode.
        num_batch: number of batches to consume (``<= 0`` means all).
        k: number of largest eigenvalues in the numerator. ``k <= 0`` gives 0;
            ``k`` larger than the sample count uses the whole spectrum.

    Returns:
        A list of floats, one per network.
    """
    device = torch.cuda.current_device()
    for network in networks:
        network_weight_gaussian_init(network)
        if train_mode:
            network.train()
        else:
            network.eval()

    grads = [[] for _ in networks]
    for i, (inputs, _targets) in enumerate(xloader):
        if num_batch > 0 and i >= num_batch:
            break
        inputs = inputs.cuda(device=device, non_blocking=True)
        for net_idx, network in enumerate(networks):
            # Clone so an in-place op in one network cannot leak into the next.
            grads[net_idx].extend(compute_ntk_grads(inputs.clone(), network))

    conds = []
    for net_grads in grads:
        jac = torch.stack(net_grads, 0)
        ntk = torch.einsum('nc,mc->nm', [jac, jac])
        eigenvalues = torch.linalg.eigvalsh(ntk, UPLO='U')  # ascending order
        top_k = eigenvalues[max(eigenvalues.numel() - k, 0):]
        conds.append(np.nan_to_num((top_k.sum() / eigenvalues.sum()).item(), copy=True))
    return conds


def compute_ntk_grads(inputs, network):
    """Return one flattened gradient vector per sample in ``inputs``.

    For sample ``j`` the logits of that sample are back-propagated with a
    vector of ones (i.e. their sum). The gradients of every parameter whose
    name contains ``'weight'`` and that received a gradient are then
    flattened and concatenated in ``named_parameters()`` order.

    If the network returns a tuple, element 1 is treated as the logits
    (NAS-Bench-201 style ``(features, logits)``).
    """
    network.zero_grad()
    out = network(inputs)
    logits = out[1] if isinstance(out, tuple) else out

    per_sample = []
    for j in range(len(inputs)):
        row = logits[j:j + 1]
        # The graph is reused for every remaining sample, so keep it alive.
        row.backward(torch.ones_like(row), retain_graph=True)
        flat = [
            param.grad.view(-1).detach()
            for pname, param in network.named_parameters()
            if 'weight' in pname and param.grad is not None
        ]
        per_sample.append(torch.cat(flat, -1))
        network.zero_grad()
        torch.cuda.empty_cache()
    return per_sample


def _stack_ntk_grads_over_loader(loader, networks, num_batch, device):
    """Collect per-sample NTK gradients of every network over ``loader``.

    Returns ``(grads, targets)``:
    - ``grads[i]`` stacks, for ``networks[i]``, one gradient row per sample
      from all consumed batches, in loader order;
    - ``targets`` is the list of per-batch target tensors (detached).
    """
    per_net = [[] for _ in networks]
    batch_targets = []
    for i, (inputs, targets) in enumerate(loader):
        if num_batch > 0 and i >= num_batch:
            break
        inputs = inputs.cuda(device=device, non_blocking=True)
        for net_idx, network in enumerate(networks):
            # extend (not append) so samples from all batches share one matrix.
            per_net[net_idx].extend(compute_ntk_grads(inputs, network))
        batch_targets.append(targets.detach())
    return [torch.stack(g, 0) for g in per_net], batch_targets


def get_ntk_n_v2(train_loader, valid_loader, networks, metric=MetricType.COND, train_mode=False, as_correlation=False, train_iters=-1, num_batch=-1, verbose=False):
    """Score networks from their empirical NTK, optionally via kernel regression.

    Per-sample gradients (``compute_ntk_grads``) are gathered over up to
    ``num_batch`` batches of ``train_loader`` (all if ``num_batch <= 0``) and
    stacked into one matrix per network. The train NTK is their Gram matrix,
    or ``torch.corrcoef`` of the rows when ``as_correlation`` is set.

    Scoring depends on ``metric``:
    - If ``MetricType.require_only_matrix(metric)``: ``metric(ntk)`` per network.
    - If ``metric is MetricType.LGA``: ``-metric(ntk, Y_train)`` with float
      one-hot train targets.
    - Otherwise the same gradients are gathered on ``valid_loader``, and the
      kernel-regression prediction ``K_vt @ solve(K_tt, Y_train)`` is scored
      with ``metric(Y_val, prediction)``.

    ``num_classes`` is taken from ``valid_loader.dataset.classes``. ``verbose``
    is unused.

    Returns:
        A list with one score per network.
    """
    device = torch.cuda.current_device()
    for network in networks:
        if train_iters > 0:
            # NOTE(review): ``slight_train`` is not defined or imported in this
            # file; this branch raises NameError unless it is provided elsewhere.
            slight_train(network, train_loader, train_iters, device)
        if train_mode:
            network.train()
        else:
            network.eval()

    train_grads, train_targets = _stack_ntk_grads_over_loader(train_loader, networks, num_batch, device)
    if as_correlation:
        train_ntks = [torch.corrcoef(g) for g in train_grads]
    else:
        train_ntks = [torch.einsum('nc,mc->nm', [g, g]) for g in train_grads]

    if MetricType.require_only_matrix(metric):
        return [metric(ntk) for ntk in train_ntks]

    num_classes = len(valid_loader.dataset.classes)
    train_targets = torch.concat(train_targets, 0)
    train_targets = one_hot(train_targets, num_classes=num_classes).to(torch.float32).cuda(device=device, non_blocking=True)

    if metric is MetricType.LGA:
        return [-1 * metric(ntk, train_targets) for ntk in train_ntks]

    val_grads, val_targets = _stack_ntk_grads_over_loader(valid_loader, networks, num_batch, device)
    val_targets = torch.concat(val_targets, 0)
    val_targets = one_hot(val_targets, num_classes=num_classes).to(torch.float32).cuda(device=device, non_blocking=True)

    scores = []
    for ntk_tt, g_val, g_train in zip(train_ntks, val_grads, train_grads):
        # Cross kernel between validation and training samples (always a Gram
        # product, even when the train kernel is a correlation matrix).
        ntk_vt = torch.einsum('nc,mc->nm', [g_val, g_train])
        inv_labels = torch.linalg.solve(ntk_tt, train_targets)
        prediction = torch.matmul(ntk_vt, inv_labels)
        scores.append(metric(val_targets, prediction))
    return scores


def get_ntk_plus(xloader, vloader, networks, mixup_gamma=1e-2, recalbn=0, train_mode=False, num_batch=-1, num_classes=100):
    """NTK spectrum score plus an NTK kernel-regression error, per network.

    1. Per-sample gradients on ``xloader`` batches give the train NTK ``K_xx``.
       ``score_2`` is ``sum(top-8 eigenvalues) / sum(all eigenvalues)``.
    2. For each evaluation batch the samples from index 32 onward are kept and
       a noisy copy ``x + mixup_gamma * N(0, 1)`` is appended. The evaluation
       batches come from ``vloader``, or from ``xloader`` when ``vloader`` is
       ``None``.
    3. Kernel regression ``PY = K_yx @ inverse(K_xx) @ Y_x`` uses per-batch
       mean-centred one-hot labels. ``score_4`` is the mean squared error of
       ``PY[32:]`` (the noisy copies) against the centred labels ``[32:]``.

    Each network is scored with its own kernels.

    NOTE(review): the fixed index 32 assumes a single evaluation batch of 64
    samples; other sizes fail or mix rows.

    Args:
        mixup_gamma: std of the Gaussian perturbation added to evaluation inputs.
        recalbn: accepted for API compatibility; unused.
        num_batch: batches to consume from each loader (``<= 0`` means all).
        num_classes: width of the one-hot label encoding.

    Returns:
        ``(score_2, score_4)``, each a list with one float per network.
    """
    device = torch.cuda.current_device()
    for network in networks:
        if train_mode:
            network.train()
        else:
            network.eval()

    # ---- Train side: per-sample gradients and centred labels ----
    grads_x = [[] for _ in networks]
    targets_x = []
    for i, (inputs, targets) in enumerate(xloader):
        if num_batch > 0 and i >= num_batch:
            break
        inputs = inputs.cuda(device=device, non_blocking=True)
        targets = targets.cuda(device=device, non_blocking=True)
        onehot = F.one_hot(targets, num_classes=num_classes).float()
        targets_x.append(onehot - onehot.mean(0))
        for net_idx, network in enumerate(networks):
            # Clone so an in-place op in one network cannot leak into the next.
            grads_x[net_idx].extend(compute_ntk_grads(inputs.clone(), network))
    # Concatenate once after the loop; the list is still appended to above.
    targets_x = torch.cat(targets_x, 0)
    grads_x = [torch.stack(g, 0) for g in grads_x]
    ntks = [torch.einsum('nc,mc->nm', [g, g]) for g in grads_x]

    score_2 = []
    for ntk in ntks:
        eigenvalues = torch.linalg.eigvalsh(ntk, UPLO='U')  # ascending order
        score_2.append(np.nan_to_num((eigenvalues[-8:].sum() / eigenvalues.sum()).item(), copy=True))

    # ---- Evaluation side: clean samples plus noisy copies ----
    eval_loader = xloader if vloader is None else vloader
    grads_y = [[] for _ in networks]
    targets_y = []
    for i, (inputs, targets) in enumerate(eval_loader):
        if num_batch > 0 and i >= num_batch:
            break
        input_1 = inputs[32:]
        input_mix = input_1 + mixup_gamma * torch.randn((32,) + inputs.size()[1:])
        inputs = torch.cat((input_1, input_mix), dim=0).cuda(device=device, non_blocking=True)
        targets = targets.cuda(device=device, non_blocking=True)
        onehot = F.one_hot(targets, num_classes=num_classes).float()
        targets_y.append(onehot - onehot.mean(0))
        for net_idx, network in enumerate(networks):
            grads_y[net_idx].extend(compute_ntk_grads(inputs.clone(), network))
    targets_y = torch.cat(targets_y, 0)
    grads_y = [torch.stack(g, 0) for g in grads_y]

    score_4 = []
    for net_idx in range(len(networks)):
        ntk_yx = torch.einsum('nc,mc->nm', [grads_y[net_idx], grads_x[net_idx]])
        PY = torch.einsum('jk,kl,lm->jm', ntk_yx, torch.inverse(ntks[net_idx]), targets_x)
        score_4.append(((PY[32:] - targets_y[32:]) ** 2).sum(1).mean(0).item())
    return score_2, score_4


def get_ntk_lga(xloader, networks, recalbn=0, train_mode=False, num_batch=-1):
    """Label-gradient alignment (LGA) between each network's NTK and the labels.

    After Gaussian re-initialisation (``network_weight_gaussian_init``), the
    NTK ``K`` is built from per-sample gradients over ``xloader``. The label
    kernel ``L`` is ``+1`` for pairs of samples with the same class and ``-1``
    otherwise. Both are centred by their global mean. The score is their
    normalised Frobenius inner product::

        <K - mean(K), L - mean(L)> / (||K - mean(K)|| * ||L - mean(L)||)

    If every sample has the same class, the centred label kernel is zero and
    the score is NaN.

    Args:
        xloader: iterable of ``(inputs, targets)`` batches with integer targets.
        networks: models to score (re-initialised in place).
        recalbn: accepted for API compatibility; unused.
        train_mode: run the networks in train mode instead of eval mode.
        num_batch: number of batches to consume (``<= 0`` means all).

    Returns:
        A list of floats, one per network.
    """
    device = torch.cuda.current_device()
    for network in networks:
        network_weight_gaussian_init(network)
        if train_mode:
            network.train()
        else:
            network.eval()

    grads = [[] for _ in networks]
    all_targets = []
    for i, (inputs, targets) in enumerate(xloader):
        if num_batch > 0 and i >= num_batch:
            break
        inputs = inputs.cuda(device=device, non_blocking=True)
        all_targets.append(targets.cuda(device=device, non_blocking=True))
        for net_idx, network in enumerate(networks):
            # Clone so an in-place op in one network cannot leak into the next.
            grads[net_idx].extend(compute_ntk_grads(inputs.clone(), network))

    targets = torch.cat(all_targets, 0)
    # Build the +1/-1 label kernel from class equality. Thresholding dot
    # products of batch-mean-centred one-hots at 1 would not work: those
    # products are < 1 even for same-class pairs, so every entry would become -1.
    same_class = targets.unsqueeze(0) == targets.unsqueeze(1)
    labels = same_class.float() * 2 - 1
    labels_centered = labels - labels.mean()
    labels_norm = torch.norm(labels_centered, 2)

    lgas = []
    for net_grads in grads:
        jac = torch.stack(net_grads, 0)
        ntk = torch.einsum('nc,mc->nm', [jac, jac])
        ntk_centered = ntk - ntk.mean()
        score = (ntk_centered * labels_centered).sum() / (torch.norm(ntk_centered, 2) * labels_norm)
        lgas.append(score.item())
    return lgas
    

def get_ntk_fnorm(xloader, networks, recalbn=0, train_mode=False, num_batch=-1):
    """Frobenius norm of each network's empirical NTK at Gaussian initialisation.

    Args:
        xloader: iterable of ``(inputs, targets)`` batches; targets are unused.
        networks: models to score; re-initialised in place with
            ``network_weight_gaussian_init``.
        recalbn: accepted for API compatibility; unused.
        train_mode: run the networks in train mode instead of eval mode.
        num_batch: number of batches to consume (``<= 0`` means all).

    Returns:
        A list of floats, one per network.
    """
    device = torch.cuda.current_device()
    for network in networks:
        network_weight_gaussian_init(network)
        if train_mode:
            network.train()
        else:
            network.eval()

    grads = [[] for _ in networks]
    for i, (inputs, _targets) in enumerate(xloader):
        if num_batch > 0 and i >= num_batch:
            break
        inputs = inputs.cuda(device=device, non_blocking=True)
        for net_idx, network in enumerate(networks):
            # Clone so an in-place op in one network cannot leak into the next.
            grads[net_idx].extend(compute_ntk_grads(inputs.clone(), network))

    fnorms = []
    for net_grads in grads:
        jac = torch.stack(net_grads, 0)
        ntk = torch.einsum('nc,mc->nm', [jac, jac])
        fnorms.append(torch.norm(ntk, p="fro").item())
    return fnorms


def get_ntk_n_comb(xloader, vloader, networks, mixup_gamma=1e-2, recalbn=0, train_mode=False, num_batch=-1, num_classes=100):
    """Three NTK-based scores per network: two spectral, one stability.

    1. Per-sample gradients on ``xloader`` batches give the train NTK ``K_xx``.
       From it:
       - ``score_1`` = ``lambda_max / lambda_min * (1 + 1 / log10(lambda_max))``;
       - ``score_2`` = ``sum(top-8 eigenvalues) / sum(all eigenvalues)``.
    2. For each evaluation batch the samples from index 32 onward are kept and
       a noisy copy ``x + mixup_gamma * N(0, 1)`` is appended. The evaluation
       batches come from ``vloader``, or from ``xloader`` when ``vloader`` is
       ``None``.
    3. Kernel regression ``PY = K_yx @ inverse(K_xx) @ Y_x`` uses per-batch
       mean-centred one-hot labels. ``score_3`` is
       ``-mean ||PY[32:] - PY[:32]||^2``, i.e. how much predictions move under
       the perturbation.

    Each network is scored with its own kernels.

    NOTE(review): the fixed index 32 assumes a single evaluation batch of 64
    samples; other sizes fail or mix rows.

    Returns:
        ``(score_1, score_2, score_3)``, each a list with one float per network.
    """
    device = torch.cuda.current_device()
    for network in networks:
        if train_mode:
            network.train()
        else:
            network.eval()

    # ---- Train side: per-sample gradients and centred labels ----
    grads_x = [[] for _ in networks]
    targets_x = []
    for i, (inputs, targets) in enumerate(xloader):
        if num_batch > 0 and i >= num_batch:
            break
        inputs = inputs.cuda(device=device, non_blocking=True)
        targets = targets.cuda(device=device, non_blocking=True)
        onehot = F.one_hot(targets, num_classes=num_classes).float()
        targets_x.append(onehot - onehot.mean(0))
        for net_idx, network in enumerate(networks):
            # Clone so an in-place op in one network cannot leak into the next.
            grads_x[net_idx].extend(compute_ntk_grads(inputs.clone(), network))
    # Concatenate once after the loop; the list is still appended to above.
    targets_x = torch.cat(targets_x, 0)
    grads_x = [torch.stack(g, 0) for g in grads_x]
    ntks = [torch.einsum('nc,mc->nm', [g, g]) for g in grads_x]

    score_1, score_2 = [], []
    for ntk in ntks:
        eigenvalues = torch.linalg.eigvalsh(ntk, UPLO='U')  # ascending order
        lam_min, lam_max = eigenvalues[0], eigenvalues[-1]
        score_1.append(np.nan_to_num((lam_max / lam_min * (1 + 1 / torch.log10(lam_max))).item(), copy=True))
        score_2.append(np.nan_to_num((eigenvalues[-8:].sum() / eigenvalues.sum()).item(), copy=True))

    # ---- Evaluation side: clean samples plus noisy copies ----
    eval_loader = xloader if vloader is None else vloader
    grads_y = [[] for _ in networks]
    for i, (inputs, _targets) in enumerate(eval_loader):
        if num_batch > 0 and i >= num_batch:
            break
        input_1 = inputs[32:]
        input_mix = input_1 + mixup_gamma * torch.randn((32,) + inputs.size()[1:])
        inputs = torch.cat((input_1, input_mix), dim=0).cuda(device=device, non_blocking=True)
        for net_idx, network in enumerate(networks):
            grads_y[net_idx].extend(compute_ntk_grads(inputs.clone(), network))
    grads_y = [torch.stack(g, 0) for g in grads_y]

    score_3 = []
    for net_idx in range(len(networks)):
        ntk_yx = torch.einsum('nc,mc->nm', [grads_y[net_idx], grads_x[net_idx]])
        PY = torch.einsum('jk,kl,lm->jm', ntk_yx, torch.inverse(ntks[net_idx]), targets_x)
        # Rows [32:] are the noisy copies of rows [:32].
        score_3.append(-1 * ((PY[32:] - PY[:32, :]) ** 2).sum(1).mean(0).item())
    return score_1, score_2, score_3


def get_ntk_n_zen(xloader, vloader, networks, recalbn=0, train_mode=False, num_batch=-1, mixup_gamma=1e-2, num_classes=100):
    """Compute NTK-based scores from full-network and "cell"-only gradients.

    Steps as implemented below:
      1. On ``xloader`` batches, collect per-sample gradients of the summed
         logits w.r.t. every parameter whose name contains ``'weight'``
         (``grads_x``). Separately, keep only those whose name also contains
         ``"cell"`` (``cellgrads_x``). Per-batch mean-centred one-hot labels
         go into ``targets_x_onehot_mean``.
      2. From the full-network NTK compute:
         - ``score_1``: lambda_max / lambda_min * (1 + 1 / log10(lambda_max));
           computed but not returned;
         - ``score_2``: top-8 eigenvalue mass fraction.
      3. On ``vloader`` batches, keep ``inputs[32:]`` and append a copy
         perturbed by ``mixup_gamma * N(0, 1)`` noise, then collect cell-only
         gradients (``cellgrads_y``).
      4. Run kernel regression with the cell NTK,
         ``PY = K_yx @ inverse(K_xx) @ Y_x``. Then:
         - ``score_4`` is the squared error of ``PY[32:]`` against the
           centred labels ``[32:]``;
         - ``score_3`` is ``-cross_entropy(PY[:32], PY[32:])``.
         Networks whose cell gradients sum to zero get ``-1`` for both.

    NOTE(review): the hard-coded 32 looks like it assumes evaluation batches
    of 64 samples and a single evaluation batch; other sizes would fail or mix
    rows. Confirm with callers. ``recalbn`` is unused.

    Returns:
        ``(score_2, score_3, score_4)``, each a list with one entry per network.
    """
    device = torch.cuda.current_device()
    ntks = []
    for network in networks:
        if train_mode:
            network.train()
        else:
            network.eval()
    
    # Per-network lists of per-sample gradient vectors: all 'weight' params
    # (grads_x) and only params whose name contains "cell" (cellgrads_*).
    # NOTE: ntk_cell_yx and prediction_mses are never used below.
    grads_x = [[] for _ in range(len(networks))]
    cellgrads_x = [[] for _ in range(len(networks))]; cellgrads_y = [[] for _ in range(len(networks))]
    ntk_cell_x = []; ntk_cell_yx = []; prediction_mses = []
    targets_x_onehot_mean = []; targets_y_onehot_mean = []
    
    # ---- Pass 1: xloader batches -> full and cell-only gradients ----
    for i, (inputs, targets) in enumerate(xloader):
        if num_batch > 0 and i >= num_batch: break
        inputs = inputs.cuda(device=device, non_blocking=True)
        targets = targets.cuda(device=device, non_blocking=True)
        targets_onehot = torch.nn.functional.one_hot(targets, num_classes=num_classes).float()
        # Centre the one-hot labels by their per-batch mean.
        targets_onehot_mean = targets_onehot - targets_onehot.mean(0)
        targets_x_onehot_mean.append(targets_onehot_mean)
        for net_idx, network in enumerate(networks):
            network.zero_grad()
            inputs_ = inputs.clone().cuda(device=device, non_blocking=True)
            logit = network(inputs_)
            if isinstance(logit, tuple):
                logit = logit[1]  # 201 networks: return features and logits
            for _idx in range(len(inputs_)):
                # Backprop one sample's summed logits; the graph is retained
                # because it is reused for the remaining samples.
                logit[_idx:_idx+1].backward(torch.ones_like(logit[_idx:_idx+1]), retain_graph=True)
                grad = []
                cellgrad = []
                for name, W in network.named_parameters():
                    if 'weight' in name and W.grad is not None:
                        grad.append(W.grad.view(-1).detach())
                        if "cell" in name:
                            cellgrad.append(W.grad.view(-1).detach())
                grads_x[net_idx].append(torch.cat(grad, -1))
                # Without any "cell" weights, a single-zero placeholder is stored;
                # the ``sum() == 0`` check further down then skips this network.
                cellgrad = torch.cat(cellgrad, -1) if len(cellgrad) > 0 else torch.Tensor([0]).cuda()
                if len(cellgrads_x[net_idx]) == 0:
                    cellgrads_x[net_idx] = [cellgrad]
                else:
                    cellgrads_x[net_idx].append(cellgrad)
                network.zero_grad()
                torch.cuda.empty_cache()
    targets_x_onehot_mean = torch.cat(targets_x_onehot_mean, 0)
    # cell's NTK #####
    # Replace each network's list of cell gradients with the stacked matrix and
    # keep its Gram matrix (the cell NTK on the x samples).
    for _i, grads in enumerate(cellgrads_x):
        grads = torch.stack(grads, 0)
        _ntk = torch.einsum('nc,mc->nm', [grads, grads])
        ntk_cell_x.append(_ntk)
        cellgrads_x[_i] = grads
    # NTK cond
    # Spectral scores from the full-network NTK (eigvalsh returns ascending order).
    grads_x = [torch.stack(_grads, 0) for _grads in grads_x]
    ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in grads_x]
    conds_x = []
    score_1, score_2, score_3, score_4 = [], [], [], []
    for ntk in ntks:
        # eigenvalues, _ = torch.symeig(ntk)  # ascending
        eigenvalues = torch.linalg.eigvalsh(ntk, UPLO='U')
        # conds_x.append(np.nan_to_num((eigenvalues[0]).item(), copy=True, nan=100000.0))
        # conds_x.append(np.nan_to_num((eigenvalues[0] / eigenvalues[-1]).item(), copy=True, nan=100000.0))
        # conds_x.append(np.nan_to_num((-1 * eigenvalues[-1] / eigenvalues[0]).item(), copy=True))
        # score_1.append(np.nan_to_num((eigenvalues[-1] / eigenvalues[0]).item(), copy=True))
        score_1.append(np.nan_to_num((eigenvalues[-1] / eigenvalues[0] * (1 + 1 / torch.log10(eigenvalues[-1]))).item(), copy=True))
        score_2.append(np.nan_to_num((eigenvalues[-8:].sum() / eigenvalues.sum()).item(), copy=True))
        # conds_x.append(np.nan_to_num((-1 * (eigenvalues[-1] / eigenvalues[0] / 10 + eigenvalues[-8:].sum() / eigenvalues.sum())).item(), copy=True))
        # conds_x.append(np.nan_to_num((-1 * (eigenvalues[-8:].sum() / eigenvalues.sum() + 1 / torch.log10(eigenvalues.sum() / len(eigenvalues)))).item(), copy=True))
        # conds_x.append(np.nan_to_num((-1 * (eigenvalues[-1] / eigenvalues[0] + 500 / torch.log10(eigenvalues[-1]))).item(), copy=True))
        # new_eigenvalues = torch.square(eigenvalues)
        # conds_x.append(np.nan_to_num((1 - new_eigenvalues[-8:].sum() / new_eigenvalues.sum()).item(), copy=True))
    
    # Val / Test set
    # ---- Pass 2: vloader batches -> cell-only gradients on [x, x + noise] ----
    for i, (inputs, targets) in enumerate(vloader):
        if num_batch > 0 and i >= num_batch: break
        # Keep inputs[32:] and append a Gaussian-perturbed copy of the same rows.
        input_1 = inputs[32:]
        size = (32,) + inputs.size()[1:]
        input_2 = torch.randn(size)
        input_mix = input_1 + mixup_gamma * input_2
        # print(input_mix - input_1)
        # print(((input_mix - input_1)**2).sum(1).mean(0).item())
        # print(aaa)
        inputs = torch.cat((input_1, input_mix), dim=0)
        inputs = inputs.cuda(device=device, non_blocking=True)
        targets = targets.cuda(device=device, non_blocking=True)
        # Labels of the ORIGINAL batch; only rows [32:] are used later, and those
        # correspond to the samples behind both halves of the new inputs.
        targets_onehot = torch.nn.functional.one_hot(targets, num_classes=num_classes).float()
        targets_onehot_mean = targets_onehot - targets_onehot.mean(0)
        targets_y_onehot_mean.append(targets_onehot_mean)
        for net_idx, network in enumerate(networks):
            network.zero_grad()
            inputs_ = inputs.clone().cuda(device=device, non_blocking=True)
            logit = network(inputs_)
            if isinstance(logit, tuple):
                logit = logit[1]  # 201 networks: return features and logits
            for _idx in range(len(inputs_)):
                logit[_idx:_idx+1].backward(torch.ones_like(logit[_idx:_idx+1]), retain_graph=True)
                cellgrad = []
                for name, W in network.named_parameters():
                    if 'weight' in name and W.grad is not None and "cell" in name:
                        cellgrad.append(W.grad.view(-1).detach())
                cellgrad = torch.cat(cellgrad, -1) if len(cellgrad) > 0 else torch.Tensor([0]).cuda()
                if len(cellgrads_y[net_idx]) == 0:
                    cellgrads_y[net_idx] = [cellgrad]
                else:
                    cellgrads_y[net_idx].append(cellgrad)
                network.zero_grad()
                torch.cuda.empty_cache()
    targets_y_onehot_mean = torch.cat(targets_y_onehot_mean, 0)
    for _i, grads in enumerate(cellgrads_y):
        grads = torch.stack(grads, 0)
        cellgrads_y[_i] = grads
    # ---- Kernel regression with each network's own cell NTK ----
    for net_idx in range(len(networks)):
        if cellgrads_y[net_idx].sum() == 0 or cellgrads_x[net_idx].sum() == 0:
            # bad gradients
            score_3.append(-1)
            score_4.append(-1)
            continue
        _ntk_yx = torch.einsum('nc,mc->nm', [cellgrads_y[net_idx], cellgrads_x[net_idx]])
        PY = torch.einsum('jk,kl,lm->jm', _ntk_yx, torch.inverse(ntk_cell_x[net_idx]), targets_x_onehot_mean)
        score_4.append(((PY[32:] - targets_y_onehot_mean[32:])**2).sum(1).mean(0).item())
        # score_4.append(0)
        # score_4.append(cross_entropy(PY[32:], targets_y_onehot_mean[32:]).item())
        # score_3.append(-1 * ((PY[32:] - PY[:32, :])**2).sum(1).mean(0).item())
        # Uses the module-level cross_entropy: clean-sample predictions as the
        # first argument, noisy-copy predictions as the second.
        score_3.append(-1 * cross_entropy(PY[:32, :], PY[32:]).item())
        # score_3.append(-1 * ((PY[32:] - PY[:32, :])**2).sum(1).mean(0).item())
        # score_3.append(1 / math.log10(((PY[32:] - PY[:32, :])**2).sum(1).mean(0).item()))
        # score_3.append(0)
        # score_3.append(1 / cross_entropy(PY[32:], PY[:32, :]).item())
        # score_3.append(np.nan_to_num((torch.sum(torch.abs(PY[:32, :] - PY[32:])) / 8).item(), copy=True))
    ######

    # score_1 is computed above but not returned.
    return score_2, score_3, score_4
